#include <gtest/gtest.h>
#include <iostream>
#define protected public
#define private public
#include "framework/graph/core/op/op_desc.h"
#include "graph/op/array_defs.h"
#include "graph/tensor.h"
#include "graph/attr_value.h"

#include "graph/persistance/interface/attr_def.h"
#include "graph/persistance/interface/attr_map_def.h"
#include "graph/persistance/proxy/proto_factory.h"
#include "graph/core/op/op_desc_factory.h"
#include "graph/core/op/constholder_op_desc.h"
#undef protected
#undef private
using namespace std;
using namespace ge;

class GeConstHolderOpDescUnittest : public testing::Test {
protected:
    void SetUp()
    {
    }

    void TearDown()
    {
    }
};

// Checks that a freshly built ConstHolderOpDesc is recognized by
// ConstHolderOpDesc::IsConstHolderOp, and that its name/type accessors
// round-trip through SetName/SetType.
TEST_F(GeConstHolderOpDescUnittest, GeConstHolderOpDescUnittest_common)
{
    string name = "test_const";
    string type = "Const";
    OpDescPtr opDesc = std::shared_ptr<ConstHolderOpDesc>(new (std::nothrow) ConstHolderOpDesc(name, type));
    // Fatal check: every following statement dereferences opDesc, so a null
    // allocation must abort the test instead of crashing the test binary.
    ASSERT_NE(opDesc, nullptr);
    TensorDesc tedesc1(Shape({1, 2, 3, 4}), FORMAT_NCHW, DT_FLOAT);
    EXPECT_EQ(GRAPH_SUCCESS, opDesc->AddOutputDesc(tedesc1));
    EXPECT_TRUE(ConstHolderOpDesc::IsConstHolderOp(*opDesc));
    EXPECT_EQ(name, opDesc->GetName());
    EXPECT_EQ(type, opDesc->GetType());

    // Renaming / retyping must be reflected by the getters.
    name = name + "_modify";
    type = type + "_modify";
    opDesc->SetName(name);
    opDesc->SetType(type);
    EXPECT_EQ(name, opDesc->GetName());
    EXPECT_EQ(type, opDesc->GetType());
}

// With no holded op attached, every output accessor of a ConstHolderOpDesc
// must report the descriptor it owns itself (shape {1, 2, 3, 4}, one output).
TEST_F(GeConstHolderOpDescUnittest, GeConstHolderOpDescUnittest_GetOutputs_WithoutHoldedOpDesc)
{
    string name = "test_const";
    string type = "Const";
    ConstHolderOpDescPtr opDesc = std::shared_ptr<ConstHolderOpDesc>(new (std::nothrow) ConstHolderOpDesc(name, type));
    // Fatal: opDesc is dereferenced throughout the rest of the test.
    ASSERT_NE(opDesc, nullptr);
    TensorDesc tedesc1(Shape({1, 2, 3, 4}), FORMAT_NCHW, DT_FLOAT);
    EXPECT_EQ(GRAPH_SUCCESS, opDesc->AddOutputDesc(tedesc1));

    // Reference accessor.
    const TensorDesc& getTe1 = opDesc->GetOutputDesc(0);
    EXPECT_EQ(getTe1.GetShape().GetDim(0), 1);
    EXPECT_EQ(getTe1.GetShape().GetDim(1), 2);
    EXPECT_EQ(getTe1.GetShape().GetDim(2), 3);
    EXPECT_EQ(getTe1.GetShape().GetDim(3), 4);

    // Read-only pointer accessor; must be non-null before it is dereferenced.
    auto getTe1ptr = opDesc->GetOutputDescPtr(0);
    ASSERT_NE(getTe1ptr, nullptr);
    EXPECT_EQ(getTe1ptr->GetShape().GetDim(0), 1);
    EXPECT_EQ(getTe1ptr->GetShape().GetDim(1), 2);
    EXPECT_EQ(getTe1ptr->GetShape().GetDim(2), 3);
    EXPECT_EQ(getTe1ptr->GetShape().GetDim(3), 4);

    // Mutable pointer accessor.
    TensorDescPtr getTe2ptr = opDesc->MutableOutputDesc(0);
    ASSERT_NE(getTe2ptr, nullptr);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(0), 1);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(1), 2);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(2), 3);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(3), 4);

    // Collection accessors all see exactly the single output added above.
    vector<TensorDescPtr> vector_out = opDesc->GetOutputsDesc();
    EXPECT_EQ(vector_out.size(), 1U); // unsigned literal: size() is size_t
    auto vistor_in = opDesc->GetAllOutputsDesc();
    EXPECT_EQ(vistor_in.size(), 1);
    auto vistor_in_ptr = opDesc->GetAllOutputsDescPtr();
    EXPECT_EQ(vistor_in_ptr.size(), 1);
    EXPECT_EQ(opDesc->GetOutputsSize(), 1);
}

// Once a holded op is attached, the ConstHolderOpDesc's output accessors must
// report the holded op's descriptor (shape {5, 6, 7, 8}) instead of its own
// ({1, 2, 3, 4}).
TEST_F(GeConstHolderOpDescUnittest, GeConstHolderOpDescUnittest_GetOutputs_WithHoldedOpDesc)
{
    string name = "test_const";
    string type = "Const";
    ConstHolderOpDescPtr opDesc = std::shared_ptr<ConstHolderOpDesc>(new (std::nothrow) ConstHolderOpDesc(name, type));
    // Fatal: opDesc is dereferenced throughout the rest of the test.
    ASSERT_NE(opDesc, nullptr);
    TensorDesc tedesc1(Shape({1, 2, 3, 4}), FORMAT_NCHW, DT_FLOAT);
    EXPECT_EQ(GRAPH_SUCCESS, opDesc->AddOutputDesc(tedesc1));

    // Before attaching: the holder's own descriptor is visible.
    const TensorDesc& getTe1 = opDesc->GetOutputDesc(0);
    EXPECT_EQ(getTe1.GetShape().GetDim(0), 1);
    EXPECT_EQ(getTe1.GetShape().GetDim(1), 2);
    EXPECT_EQ(getTe1.GetShape().GetDim(2), 3);
    EXPECT_EQ(getTe1.GetShape().GetDim(3), 4);

    OpDescPtr opDesc2 = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type));
    ASSERT_NE(opDesc2, nullptr);
    TensorDesc tedesc2(Shape({5, 6, 7, 8}), FORMAT_NCHW, DT_FLOAT);
    EXPECT_EQ(GRAPH_SUCCESS, opDesc2->AddOutputDesc(tedesc2));
    // NOTE(review): SetHoldedOpDesc receives a raw pointer; presumably it is
    // non-owning, so opDesc2 must stay alive while opDesc reads through it.
    opDesc->SetHoldedOpDesc(opDesc2.get());

    // After attaching: the holded op's descriptor is visible through every accessor.
    const TensorDesc& getTe2 = opDesc->GetOutputDesc(0);
    EXPECT_EQ(getTe2.GetShape().GetDim(0), 5);
    EXPECT_EQ(getTe2.GetShape().GetDim(1), 6);
    EXPECT_EQ(getTe2.GetShape().GetDim(2), 7);
    EXPECT_EQ(getTe2.GetShape().GetDim(3), 8);

    auto getTe2ptr = opDesc->GetOutputDescPtr(0);
    ASSERT_NE(getTe2ptr, nullptr);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(0), 5);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(1), 6);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(2), 7);
    EXPECT_EQ(getTe2ptr->GetShape().GetDim(3), 8);

    TensorDescPtr getTe3ptr = opDesc->MutableOutputDesc(0);
    ASSERT_NE(getTe3ptr, nullptr);
    EXPECT_EQ(getTe3ptr->GetShape().GetDim(0), 5);
    EXPECT_EQ(getTe3ptr->GetShape().GetDim(1), 6);
    EXPECT_EQ(getTe3ptr->GetShape().GetDim(2), 7);
    EXPECT_EQ(getTe3ptr->GetShape().GetDim(3), 8);

    vector<TensorDescPtr> vector_out = opDesc->GetOutputsDesc();
    EXPECT_EQ(vector_out.size(), 1U); // unsigned literal: size() is size_t
    auto vistor_in = opDesc->GetAllOutputsDesc();
    EXPECT_EQ(vistor_in.size(), 1);
    auto vistor_in_ptr = opDesc->GetAllOutputsDescPtr();
    EXPECT_EQ(vistor_in_ptr.size(), 1);
    EXPECT_EQ(opDesc->GetOutputsSize(), 1);
}

// SetHoldedOpDesc/GetHoldedOpDesc must round-trip both nullptr and a real
// OpDesc pointer.
TEST_F(GeConstHolderOpDescUnittest, GeConstHolderOpDescUnittest_SetHoldedOpDesc)
{
    string name = "test_const";
    string type = "Const";
    ConstHolderOpDescPtr opDesc = std::shared_ptr<ConstHolderOpDesc>(new (std::nothrow) ConstHolderOpDesc(name, type));
    // Fatal: opDesc is dereferenced below.
    ASSERT_NE(opDesc, nullptr);
    TensorDesc tedesc1(Shape({1, 2, 3, 4}), FORMAT_NCHW, DT_FLOAT);
    EXPECT_EQ(GRAPH_SUCCESS, opDesc->AddOutputDesc(tedesc1));

    // Clearing the holded op is accepted and observable.
    opDesc->SetHoldedOpDesc(nullptr);
    EXPECT_EQ(opDesc->GetHoldedOpDesc(), nullptr);

    OpDescPtr opDesc2 = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type));
    ASSERT_NE(opDesc2, nullptr);
    EXPECT_EQ(GRAPH_SUCCESS, opDesc2->AddOutputDesc(tedesc1));

    // The getter returns exactly the pointer that was stored.
    opDesc->SetHoldedOpDesc(opDesc2.get());
    EXPECT_EQ(opDesc->GetHoldedOpDesc(), opDesc2.get());
}

// Cloning a ConstHolderOpDesc must keep the is_const_holder attribute and the
// pointer to the holded op.
TEST_F(GeConstHolderOpDescUnittest, GeConstHolderOpDescUnittest_Clone)
{
    string name = "test_const";
    string type = "Const";
    ConstHolderOpDescPtr opDesc = std::shared_ptr<ConstHolderOpDesc>(new (std::nothrow) ConstHolderOpDesc(name, type));
    ASSERT_NE(opDesc, nullptr);
    TensorDesc tedesc1(Shape({1, 2, 3, 4}), FORMAT_NCHW, DT_FLOAT);
    EXPECT_EQ(opDesc->AddOutputDesc(tedesc1), GRAPH_SUCCESS);

    OpDescPtr opDesc2 = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type));
    ASSERT_NE(opDesc2, nullptr);
    EXPECT_EQ(opDesc2->AddOutputDesc(tedesc1), GRAPH_SUCCESS);

    opDesc->SetHoldedOpDesc(opDesc2.get());
    EXPECT_EQ(opDesc->GetHoldedOpDesc(), opDesc2.get());

    OpDescPtr opDesc3 = opDesc->Clone();
    ASSERT_NE(opDesc3, nullptr);
    // Fatal: only downcast when the clone is marked as a const holder.
    ASSERT_TRUE(opDesc3->HasAttr("is_const_holder"));
    // NOTE(review): static_pointer_cast cannot verify the dynamic type; this
    // assumes Clone() on a ConstHolderOpDesc really returns a ConstHolderOpDesc.
    // A non-null input always yields a non-null result here, so no extra null
    // check is meaningful.
    ConstHolderOpDescPtr constHolderOp = std::static_pointer_cast<ConstHolderOpDesc>(opDesc3);
    EXPECT_EQ(constHolderOp->GetHoldedOpDesc(), opDesc2.get());
}

// An OpDef carrying is_const_holder=true must produce an op reporting that
// attribute, both via OpDescFactory::Create and via the
// ConstHolderOpDesc(opDef, bool) constructor.
TEST_F(GeConstHolderOpDescUnittest, GeConstHolderOpDescUnittest_CreateConstHolderOpDesc)
{
    auto constHolderOpDef = hiai::ProtoFactory::Instance()->CreateOpDef();
    // Fatal checks: each pointer is dereferenced on the following line.
    ASSERT_NE(constHolderOpDef, nullptr);
    hiai::IAttrDef* valueDef = constHolderOpDef->mutable_attr()->add_attr("is_const_holder");
    ASSERT_NE(valueDef, nullptr);
    valueDef->set_b(true);

    OpDescPtr opDesc = OpDescFactory::GetInstance().Create(constHolderOpDef);
    ASSERT_NE(opDesc, nullptr);
    EXPECT_TRUE(opDesc->HasAttr("is_const_holder"));

    // NOTE(review): constHolderOpDef is handed both to OpDescFactory::Create above
    // and to this constructor. Ownership of the def (and the meaning of the
    // `true` flag) is not visible here; confirm it is neither freed twice nor
    // leaked.
    ConstHolderOpDesc opDesc2(constHolderOpDef, true);
    EXPECT_TRUE(opDesc2.HasAttr("is_const_holder"));
}